import numpy as np
import mindspore as ms
from mind3d.models.blocks.kps_module import KPSModule


def test_kps(k=256):
    """
    Test the KPS module of Group-Free-3D by checking its output shapes.

    Builds a batch of 32 constant point clouds (2000 points with xyz
    coordinates, plus a 3-channel feature tensor laid out channel-first)
    and runs them through ``KPSModule(num_proposal=k)``. It then checks
    that the sampled coordinates and features have ``k`` entries along
    the point axis.

    Args:
        k: Number of proposals the module is asked to sample. The default
            lets pytest collect this function directly. Without it, pytest
            would try to resolve a bare ``k`` parameter as a fixture.

    Raises:
        AssertionError: If either output tensor has an unexpected shape.
    """
    batch, num_points, channels = 32, 2000, 3
    xyz = ms.Tensor(np.ones((batch, num_points, 3)), ms.float32)
    features = ms.Tensor(np.ones((batch, channels, num_points)), ms.float32)
    kps_module = KPSModule(num_proposal=k)
    # The module returns four values; only the sampled xyz and features are checked.
    new_xyz, new_features, _, _ = kps_module(xyz, features)
    assert new_xyz.shape == (batch, k, 3), 'output shape not match'
    # Compare the *shape*: comparing the tensor itself to a tuple (as before)
    # does not check the shape at all.
    # NOTE(review): expected layout (B, C, k) mirrors the channel-first input;
    # confirm against KPSModule's implementation.
    assert new_features.shape == (batch, channels, k), 'output shape not match'
    print("################################")
    print("kps test passed!")
    print("################################")
